package com.at.bigdata.spark.core.rdd.part

import org.apache.spark.{HashPartitioner, Partitioner, SparkConf, SparkContext}

/**
 * Demonstrates a custom [[Partitioner]]: records are routed to partitions
 * based on their key, rather than by the default hash of the key.
 *
 * @author cdhuangchao3
 * @date 2023/3/6 8:40 PM
 */
object Spark01_RDD_Part {

  def main(args: Array[String]): Unit = {
    // TODO Establish the connection to the Spark framework
    val sparkConf = new SparkConf()
      .setMaster("local")
      .setAppName("RDD_Part")
    val sc = new SparkContext(sparkConf)

    // A pair RDD with 3 partitions; the key decides where each record should live
    val rdd = sc.makeRDD(List(
      ("nba", "xxxxxx"),
      ("cba", "xxxxxx"),
      ("wnba", "xxxxxx"),
      ("nba", "xxxxxx")
    ), 3)
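
    // For comparison (a sketch): Spark's built-in HashPartitioner places each
    // record by the key's hashCode modulo the partition count, so "nba" and
    // "wnba" may land in any partition. partitionBy is a lazy transformation,
    // so this line triggers no job by itself.
    val hashRDD = rdd.partitionBy(new HashPartitioner(3))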

    // Redistribute the data using the custom partitioner defined below
    val partRDD = rdd.partitionBy(new MyPartitioner)
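
    // A minimal sanity check (sketch): tag each key with the index of the
    // partition it ended up in. collect() is safe here because the data is tiny.
    partRDD
      .mapPartitionsWithIndex((index, iter) => iter.map { case (key, _) => (index, key) })
      .collect()
      .foreach(println)
    // Expected: (0,nba) twice, (1,wnba) once, (2,cba) once.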

    // Note: saveAsTextFile fails if the "output" directory already exists
    partRDD.saveAsTextFile("output")

    sc.stop()
  }

  class MyPartitioner extends Partitioner {
    // Number of partitions
    override def numPartitions: Int = 3

    // Return the partition index (0-based) for the given key;
    // the result must fall in [0, numPartitions)
    override def getPartition(key: Any): Int = {
      key match {
        case "nba" => 0
        case "wnba" => 1
        case _ => 2
      }
    }
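
    // Optional addition (a sketch, not required by this example): custom
    // partitioners should override equals/hashCode so Spark can recognize
    // two RDDs partitioned by MyPartitioner as co-partitioned and avoid an
    // unnecessary shuffle, e.g. on joins.
    override def equals(other: Any): Boolean = other.isInstanceOf[MyPartitioner]
    override def hashCode(): Int = classOf[MyPartitioner].hashCode()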
  }
}
